import copy
import torchvision
import torchvision.transforms as transforms
import torch
import torch.utils.data
from torch import nn
import torchattacks
from tqdm import tqdm
import logging
import random
import os,sys
import numpy as np
import argparse
import loss_functions
from torch.optim.lr_scheduler import MultiStepLR
import time
from datetime import timedelta
from logging import getLogger
import utils
from models.resnet import resnet18
import dataloader
from functorch.experimental import replace_all_batch_norm_modules_

parser = argparse.ArgumentParser()
# NOTE(review): --gpu is parsed and logged but not applied anywhere in this script;
# every `.cuda()` call uses the current default CUDA device.
parser.add_argument('--gpu', type=int, default=0, help='GPU to use [default: GPU 0]')
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet',
					help='model architecture')
parser.add_argument('--dataset', default='cifar10', type=str,
					help='which dataset used to train')
# None means "derive from --dataset"; it is resolved right after parsing below.
parser.add_argument('--num_classes', default=None, type=int, metavar='N',
					help='number of classes (default: 10 for cifar10, 100 otherwise)')
parser.add_argument('--epochs', default=200, type=int, metavar='N',
					help='number of total epochs to run')
parser.add_argument('-b', '--batch_size', default=128, type=int,
					metavar='N',
					help='mini-batch size (default: 128)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
					metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
					help='momentum of SGD solver')
parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float,
					metavar='W', help='weight decay (default: 5e-4)',
					dest='wd')
parser.add_argument('--save', default='fgsm.pkl', type=str,
					help='model save name')
parser.add_argument('--seed', type=int,
					default=0, help='random seed')
parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.')
parser.add_argument('--eps', type=float, default=8./255., help='perturbation bound')
parser.add_argument('--ns', type=int, default=10, help='maximum perturbation step K')
parser.add_argument('--ss', type=float, default=2./255., help='step size')
# Not read directly in this script; the whole `args` namespace is handed to the
# loss_functions helpers. NOTE(review): presumably the TRADES trade-off weight -- confirm.
parser.add_argument('--beta', type=float, default=6.0)

parser.add_argument('--exp', default='fgsm', type=str,
					help='exp name')
# Only these four methods are dispatched by the training loop; anything else used to
# fail much later with a NameError.
parser.add_argument('--method', default='fgsm', type=str,
					choices=['pgd', 'fgsm', 'te', 'trades'],
					help='AT method to use')

# TE settings
parser.add_argument('--te-alpha', default=0.9, type=float,
					help='momentum term of self-adaptive training')
parser.add_argument('--start-es', default=90, type=int,
					help='start epoch of self-adaptive training (default: 90)')
parser.add_argument('--end-es', default=150, type=int,
					help='end epoch of the self-adaptive training weight ramp-up (default: 150)')
parser.add_argument('--reg-weight', default=300, type=float,
					help='peak TE loss weight, scaled by a sigmoid ramp-up between --start-es and --end-es')

parser.add_argument('--pgd', default=90, type=int,
					help="passed to PGD_TE_new as 'pgd' (te method only)")

args = parser.parse_args()

# Derive the class count from the dataset unless --num_classes was given explicitly.
# With the flag omitted this keeps the previous mapping (cifar10 -> 10, else 100).
if args.num_classes is None:
	args.num_classes = 10 if args.dataset == 'cifar10' else 100

def sigmoid_rampup(current, start_es, end_es):
	"""Return a sigmoid-shaped ramp-up weight in [0, 1] (https://arxiv.org/abs/1610.02242).

	Args:
		current: current epoch.
		start_es: epoch before which the weight is 0.
		end_es: epoch at and after which the weight is 1.

	Returns:
		0.0 when ``current < start_es``; 1.0 when ``current >= end_es``; otherwise
		``exp(-5 * (1 - t) ** 2)`` with ``t = (current - start_es) / (end_es - start_es)``.
		A degenerate interval (``end_es <= start_es``) is treated as a step at
		``start_es`` instead of dividing by zero.
	"""
	import math

	if current < start_es:
		return 0.0
	# Covers the end of the ramp and the empty/inverted interval (no division by zero).
	if current >= end_es or end_es <= start_es:
		return 1.0
	phase = 1.0 - (current - start_es) / (end_es - start_es)
	return math.exp(-5.0 * phase * phase)

# NOTE: args.gpu is not applied; select the device through CUDA_VISIBLE_DEVICES in the
# environment (all `.cuda()` calls use the current default device).
utils.setup_seed(args.seed)

# Every artefact of this run (log file, checkpoints, accuracy curves) is written under
# <dataset>/<arch>/<exp>/.
exp_dir = args.dataset + '/' + args.arch + '/' + args.exp
# exist_ok replaces the racy "os.path.exists then makedirs" check.
os.makedirs(exp_dir, exist_ok=True)
logger = utils.create_logger(os.path.join(exp_dir + '/', args.exp + ".log"), rank=0)
logger.info("============ Initialized logger ============")
logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(vars(args).items())))
# From here on args.save is a path prefix: later code appends suffixes such as
# 'best.pkl' or '_train_acc_cl.npy' directly to it.
args.save = exp_dir + '/' + args.save


# Short aliases for the optimisation hyper-parameters used further down.
wd, learning_rate, epochs, batch_size = args.wd, args.lr, args.epochs, args.batch_size
# NOTE(review): cuDNN benchmark mode autotunes convolution algorithms and may pick
# non-deterministic ones, which can undercut utils.setup_seed -- confirm this is intended.
torch.backends.cudnn.benchmark = True

# Train-time pipeline: random 32x32 crop of the 4-pixel-padded image, random horizontal
# flip, then tensor conversion (there is no Normalize step here).
transform = transforms.Compose([
	transforms.RandomCrop(32, padding=4),
	transforms.RandomHorizontalFlip(),
	transforms.ToTensor(),
])
# Test-time pipeline: resize to 32x32 and convert to a tensor, no augmentation.
transform_test = transforms.Compose([
	transforms.Resize((32, 32)),
	transforms.ToTensor(),
])

# Project dataset wrapper rooted at ./data; returns a (train, test) pair, presumably
# built with the two pipelines above -- see dataloader.Data.
data = dataloader.Data(args.dataset, './data')
trainset, testset = data.data_loader(transform, transform_test)

# Options shared by both loaders: keep the last partial batch, load in the main process.
_loader_opts = dict(batch_size=batch_size, drop_last=False, num_workers=0)
train_loader = torch.utils.data.DataLoader(trainset, shuffle=True, **_loader_opts)
test_loader = torch.utils.data.DataLoader(testset, shuffle=False, **_loader_opts)

# Build the network; only the ResNet-18 from models.resnet is wired up here.
if args.arch == 'resnet':
	n = resnet18(num_classes=args.num_classes).cuda()
else:
	# Fail with a clear message instead of a NameError on `n` a few lines below.
	raise ValueError("unsupported --arch %r; only 'resnet' is implemented" % args.arch)

# SGD with momentum and weight decay; MultiStepLR multiplies the LR by --gamma at 50%
# and 75% of the epoch budget.
optimizer = torch.optim.SGD(n.parameters(), momentum=args.momentum,
							lr=learning_rate, weight_decay=wd)
milestones = [int(args.epochs * 0.5), int(args.epochs * 0.75)]
scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=args.gamma)

# Training-set metadata. np.asarray on an existing ndarray returns it unchanged, so
# `targets_np` is simply an alias of `targets`.
data_size = len(trainset.data)
targets = np.asarray(trainset.targets)
num_classes = args.num_classes
targets_np = targets

if args.method == 'te':
	# Stateful TE objective sized by the number of training samples; the training loop
	# calls it with each batch's sample indices.
	pgd_te = loss_functions.PGD_TE_new(num_samples=data_size,
									   num_classes=args.num_classes,
									   momentum=args.te_alpha,
									   step_size=args.ss,
									   epsilon=args.eps,
									   perturb_steps=args.ns,
									   norm='linf',
									   es=args.start_es,
									   pgd=args.pgd)

# Per-epoch accuracy curves (percent) and the best robust test accuracy seen so far.
train_clean_acc = []
train_adv_acc = []
test_clean_acc = []
test_adv_acc = []
best_eval_acc = 0.0

_fmt = '{:.4f}'.format


def _train_one_epoch(epoch, weight):
	"""Run one epoch of adversarial training with the method selected by --method.

	Args:
		epoch: 0-based epoch index (progress-bar label; helpers receive ``epoch + 1``).
		weight: TE loss weight for this epoch (only used when ``--method te``).

	Returns:
		(clean accuracy %, adversarial accuracy %) over the epoch. Clean accuracy uses the
		model in eval mode on each batch before the update; adversarial accuracy uses the
		logits returned by the method helper.

	Raises:
		ValueError: if ``args.method`` is not one of pgd / fgsm / te / trades.
	"""
	loadertrain = tqdm(train_loader, desc='{} E{:03d}'.format('train', epoch), ncols=0)
	total = 0
	clean_correct = 0
	adv_correct = 0
	for x_train, y_train, idx in loadertrain:
		# NOTE(review): the model is put in eval mode every step and never in train mode
		# here; presumably the loss helpers (which receive `n`) switch modes themselves.
		n.eval()
		x_train, y_train = x_train.cuda(), y_train.cuda()
		# Clean logits are only used for the accuracy metric, so build no autograd graph.
		with torch.no_grad():
			y_pre = n(x_train)
		if args.method == 'pgd':
			logits_adv, loss = loss_functions.PGD_new(n, x_train, y_train, optimizer, epoch + 1, args)
		elif args.method == 'fgsm':
			logits_adv, loss = loss_functions.FGSM(n, x_train, y_train, optimizer, args)
		elif args.method == 'te':
			logits_adv, loss = pgd_te(x_train, y_train, idx, epoch + 1, n, optimizer, weight)
		elif args.method == 'trades':
			logits_adv, loss = loss_functions.TRADES(n, x_train, y_train, optimizer, args)
		else:
			raise ValueError('unknown --method %r' % args.method)

		# NOTE(review): optimizer.zero_grad() is not called in this loop; the helpers
		# receive `optimizer` and presumably clear gradients themselves -- confirm.
		loss.backward()
		optimizer.step()

		_, predicted = torch.max(y_pre, 1)
		_, predicted_adv = torch.max(logits_adv.detach(), 1)
		total += y_train.size(0)
		clean_correct += predicted.eq(y_train).sum().item()
		adv_correct += predicted_adv.eq(y_train).sum().item()
		loadertrain.set_postfix(loss=_fmt(loss.item()),
								acc_cl=_fmt(clean_correct / total * 100),
								acc_adv=_fmt(adv_correct / total * 100))
	return clean_correct / total * 100, adv_correct / total * 100


def _evaluate(epoch):
	"""Measure clean and PGD robust accuracy of `n` on the test set.

	The attack is torchattacks.PGD with eps=8/255 and 20 steps (other settings are
	torchattacks' defaults). It is fixed and does not follow --eps/--ns/--ss.

	Returns:
		(clean accuracy %, adversarial accuracy %).
	"""
	criterion = nn.CrossEntropyLoss().cuda()
	n.eval()
	pgd_eval = torchattacks.PGD(n, eps=8.0 / 255.0, steps=20)
	loadertest = tqdm(test_loader, desc='{} E{:03d}'.format('test', epoch), ncols=0)
	total = 0
	correct_cl = 0
	correct_adv = 0
	for x_test, y_test, _ in loadertest:
		x_test, y_test = x_test.cuda(), y_test.cuda()
		# The attack needs input gradients, so autograd stays enabled for it.
		with torch.enable_grad():
			x_adv = pgd_eval(x_test, y_test)
		n.eval()
		# Metric-only forward passes: no autograd graph needed.
		with torch.no_grad():
			y_pre = n(x_test)
			y_adv = n(x_adv)
			loss_cl = criterion(y_pre, y_test)
			loss_adv = criterion(y_adv, y_test)
		_, predicted = torch.max(y_pre, 1)
		_, predicted_adv = torch.max(y_adv, 1)
		total += y_test.size(0)
		correct_cl += predicted.eq(y_test).sum().item()
		correct_adv += predicted_adv.eq(y_test).sum().item()
		loadertest.set_postfix(loss_cl=_fmt(loss_cl.item()),
							   loss_adv=_fmt(loss_adv.item()),
							   acc_cl=_fmt(correct_cl / total * 100),
							   acc_adv=_fmt(correct_adv / total * 100))
	return correct_cl / total * 100, correct_adv / total * 100


def _save_checkpoint(path, epoch):
	"""Save `n`'s state_dict together with the 0-based `epoch` index to `path`."""
	torch.save({'state_dict': n.state_dict(), 'epoch': epoch}, path)


for epoch in range(epochs):
	# TE loss weight follows a sigmoid ramp-up between --start-es and --end-es.
	rampup_rate = sigmoid_rampup(epoch + 1, args.start_es, args.end_es)
	weight = rampup_rate * args.reg_weight

	train_cl, train_adv = _train_one_epoch(epoch, weight)
	train_clean_acc.append(train_cl)
	train_adv_acc.append(train_adv)
	scheduler.step()

	# Evaluate and checkpoint after every epoch.
	test_cl, test_adv = _evaluate(epoch)
	test_clean_acc.append(test_cl)
	test_adv_acc.append(test_adv)
	# Keep a separate checkpoint for the best robust (PGD) test accuracy so far.
	if test_adv > best_eval_acc:
		best_eval_acc = test_adv
		_save_checkpoint(args.save + 'best.pkl', epoch)
	_save_checkpoint(args.save + '%d.pkl' % (epoch + 1), epoch)
# Final checkpoint plus the per-epoch accuracy curves (percent), all next to args.save.
# `epochs - 1` is the last epoch index the training loop reaches; unlike the loop
# variable `epoch` it is also defined when --epochs is 0.
checkpoint = {
	'state_dict': n.state_dict(),
	'epoch': epochs - 1,
}
torch.save(checkpoint, args.save + 'last.pkl')
for suffix, curve in (
	('_train_acc_cl.npy', train_clean_acc),
	('_train_acc_adv.npy', train_adv_acc),
	('_test_acc_cl.npy', test_clean_acc),
	('_test_acc_adv.npy', test_adv_acc),
):
	np.save(args.save + suffix, curve)